{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 通过极简方案构建手写数字识别模型\n",
    "\n",
    "上一节介绍了创新性的“横纵式”教学法，有助于深度学习初学者快速掌握深度学习理论知识，并在过程中让读者获得到真实建模的实战体验。在“横纵式”教学法中，纵向概要介绍模型的基本代码结构和极简实现方案，如 **图1** 所示。本节将使用这种极简实现方案快速完成手写数字识别的建模。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/45b5af558f704356a1e1ba763ae35955ffa5138b6f5e400e98e710caef7b1a71\" width=\"800\" hegiht=\"\" ></center>\n",
    "<center><br>图1：“横纵式”教学法—纵向极简实现方案</br></center>\n",
    "<br></br>\n",
    "\n",
    "### 前提条件\n",
    "\n",
    "在数据处理前，首先要加载飞桨平台与“手写数字识别”模型相关类库，代码如下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#加载飞桨和相关类库\n",
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph import Linear  # 如果是之定义一层的网络，这边导入的Linear\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 数据处理\n",
    "\n",
    "飞桨提供了多个封装好的数据集API，涵盖计算机视觉、自然语言处理、推荐系统等多个领域，帮助读者快速完成机器学习任务。如在手写数字识别模型中，通过[paddle.dataset.mnist.train()](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/data/dataset_cn.html)可以直接获取处理好的MNIST训练集、验证集和测试集，飞桨API支持如下常见的学术数据集：\n",
    "\n",
    "* mnist\n",
    "* cifar\n",
    "* Conll05\n",
    "* imdb\n",
    "* imikolov\n",
    "* movielens\n",
    "* sentiment\n",
    "* uci_housing\n",
    "* wmt14\n",
    "* wmt16\n",
    "\n",
    "通过paddle.dataset.mnist.train()函数设置数据读取器，batch_size设置为8，即一个批次有8张图片和8个标签，代码如下所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz \n",
      "Begin to download\n",
      "\n",
      "Download finished\n",
      "Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz \n",
      "Begin to download\n",
      "........\n",
      "Download finished\n"
     ]
    }
   ],
   "source": [
    "# 如果～/.cache/paddle/dataset/mnist/目录下没有MNIST数据，API会自动将MINST数据下载到该文件夹下\n",
    "# 设置数据读取器，读取MNIST数据训练集\n",
    "trainset = paddle.dataset.mnist.train()\n",
    "# 包装数据读取器，每次读取的数据数量设置为batch_size=8\n",
    "train_reader = paddle.batch(trainset, batch_size=8)  # 每次读取8张图片"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "paddle.batch函数将MNIST数据集拆分成多个批次，通过如下代码读取第一个批次的数据内容，观察数据打印结果。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图像数据形状和对应数据为: (8, 784) [-1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -0.9764706  -0.85882354 -0.85882354 -0.85882354\n",
      " -0.01176471  0.06666672  0.37254906 -0.79607844  0.30196083  1.\n",
      "  0.9372549  -0.00392157 -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -0.7647059  -0.7176471  -0.26274508  0.20784318\n",
      "  0.33333337  0.9843137   0.9843137   0.9843137   0.9843137   0.9843137\n",
      "  0.7647059   0.34901965  0.9843137   0.8980392   0.5294118  -0.4980392\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -0.6156863\n",
      "  0.8666667   0.9843137   0.9843137   0.9843137   0.9843137   0.9843137\n",
      "  0.9843137   0.9843137   0.9843137   0.96862745 -0.27058822 -0.35686272\n",
      " -0.35686272 -0.56078434 -0.69411767 -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -0.85882354  0.7176471   0.9843137\n",
      "  0.9843137   0.9843137   0.9843137   0.9843137   0.5529412   0.427451\n",
      "  0.9372549   0.8901961  -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -0.372549    0.22352946 -0.1607843   0.9843137\n",
      "  0.9843137   0.60784316 -0.9137255  -1.         -0.6627451   0.20784318\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -0.8901961  -0.99215686  0.20784318  0.9843137  -0.29411763\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.          0.09019613  0.9843137   0.4901961  -0.9843137  -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -0.9137255\n",
      "  0.4901961   0.9843137  -0.45098037 -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -0.7254902   0.8901961\n",
      "  0.7647059   0.254902   -0.15294117 -0.99215686 -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -0.36470586  0.88235295  0.9843137\n",
      "  0.9843137  -0.06666666 -0.8039216  -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -0.64705884  0.45882356  0.9843137   0.9843137\n",
      "  0.17647064 -0.7882353  -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -0.8745098  -0.27058822  0.9764706   0.9843137   0.4666667\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.          0.9529412   0.9843137   0.9529412  -0.4980392  -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -0.6392157   0.0196079   0.43529415  0.9843137\n",
      "  0.9843137   0.62352943 -0.9843137  -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -0.69411767  0.16078436\n",
      "  0.79607844  0.9843137   0.9843137   0.9843137   0.9607843   0.427451\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -0.8117647  -0.10588235  0.73333335  0.9843137   0.9843137   0.9843137\n",
      "  0.9843137   0.5764706  -0.38823527 -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -0.81960785 -0.4823529   0.67058825  0.9843137\n",
      "  0.9843137   0.9843137   0.9843137   0.5529412  -0.36470586 -0.9843137\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -0.85882354  0.3411765\n",
      "  0.7176471   0.9843137   0.9843137   0.9843137   0.9843137   0.5294118\n",
      " -0.372549   -0.92941177 -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -0.5686275   0.34901965  0.77254903  0.9843137   0.9843137   0.9843137\n",
      "  0.9843137   0.9137255   0.04313731 -0.9137255  -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.          0.06666672  0.9843137\n",
      "  0.9843137   0.9843137   0.6627451   0.05882359  0.03529418 -0.8745098\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.         -1.         -1.\n",
      " -1.         -1.         -1.         -1.        ]\n",
      "图像标签形状和对应数据为: (8,) 5.0\n",
      "\n",
      "打印第一个batch的第一个图像，对应标签数字为5.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-04-07 19:25:18,789-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']\n",
      "2020-04-07 19:25:19,222-INFO: generated new fontManager\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 以迭代的形式读取数据\n",
    "for batch_id, data in enumerate(train_reader()):  # 因为每次读取8张图片，这里循环一次data就表示的是8张图片\n",
    "    # 获得图像数据，并转为float32类型的数组\n",
    "    # 因为图片是以像素的形式存在，这里就是把8张图片的数据读入img_data中，就表示为（8，28*28）的大小\n",
    "    img_data = np.array([x[0] for x in data]).astype('float32')   # x[0]表示像素（就是一张图片）， x[1]表示标签值\n",
    "    # 获得图像标签数据，并转为float32类型的数组\n",
    "    label_data = np.array([x[1] for x in data]).astype('float32')\n",
    "    # 打印数据形状\n",
    "    print(\"图像数据形状和对应数据为:\", img_data.shape, img_data[0])  # img_data是图片的数组img_data[0]表示第一张图片\n",
    "    # 打印图像对应的数据，因为图像是灰度图像，其中-1表示背景颜色，不是-1的表示手写数字的颜色， 因为图片是经过归一化处理过的了\n",
    "    print(\"图像标签形状和对应数据为:\", label_data.shape, label_data[0])\n",
    "    break\n",
    "\n",
    "print(\"\\n打印第一个batch的第一个图像，对应标签数字为{}\".format(label_data[0]))\n",
    "# 显示第一batch的第一个图像\n",
    "import matplotlib.pyplot as plt\n",
    "img = np.array(img_data[0]+1)*127.5\n",
    "img = np.reshape(img, [28, 28]).astype(np.uint8)\n",
    "\n",
    "plt.figure(\"Image\") # 图像窗口名称\n",
    "plt.imshow(img)\n",
    "plt.axis('on') # 关掉坐标轴为 off\n",
    "plt.title('image') # 图像题目\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "从打印结果看，从数据加载器train_loader()中读取一次数据，可以得到形状为（8, 784）的图像数据和形状为（8,）的标签数据。其中，形状中的数字8与设置的batch_size大小对应，784为MINIST数据集中每个图像的像素大小(28\\*28)。\n",
    "\n",
    "此外，从打印的图像数据来看，图像数据的范围是[-1, 1]，表明这是已经完成图像归一化后的图像数据，并且空白背景部分的值是-1。将图像数据反归一化，并使用matplotlib工具包将其显示出来，如 **图2** 所示。图片显示的数字是5，和对应标签数字一致。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/89634a5fc22c46ebb522a924b56b9344e6169a632b9f48898d06e0f799673302\" width=\"300\" hegiht=\"\" ></center>\n",
    "<center><br>图2：matplotlib打印结果示意图</br></center>\n",
    "<br></br>\n",
    "\n",
    "------\n",
    "**说明：**\n",
    "\n",
    "飞桨将维度是28\\*28的手写数字数据图像转成向量形式存储，因此使用飞桨数据读取到的手写数字图像是长度为784（28\\*28）的向量。\n",
    "\n",
    "------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 飞桨API的使用方法\n",
    "\n",
    "熟练掌握飞桨API的使用方法，是使用飞桨构建各类深度学习任务的基础，也是开发者必须掌握的技能，下面介绍飞桨API获取方式和使用方法。\n",
    "\n",
    "**1. 飞桨API文档获取方式**\n",
    "\n",
    "登录“[飞桨官网->文档->API Reference](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/index_cn.html)”，获取飞桨API文档，如 **图3** 所示。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/8e8d0028e7964b6bbd8fccc2639bf38f86270913f4594bf885eeaf839d6e8a43\" width=\"700\" hegiht=\"\" ></center>\n",
    "<center><br>图3：飞桨API文档</br></center>\n",
    "<br></br>\n",
    "\n",
    "**2. 通过搜索和分类浏览两种方式查阅API文档**\n",
    "\n",
    "如果用户知道需要查阅的API名称，可通过页面右上角的**搜索框**，快速获取API。\n",
    "\n",
    "如果想全面了解飞桨API文档内容，也可以在**API Reference**首页，单击“**API功能分类**”，通过概念分类获取不同职能的API，如 **图4** 所示\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/f6833d0191834057a37c44427ddb0c0abd3cb97a5b9a4d2289129b4fc9231140\" width=\"700\" hegiht=\"\" ></center>\n",
    "<center><br>图4：飞桨API功能分类页面</br></center>\n",
    "<br></br>\n",
    "\n",
    "在API功能分类的页面，用户可以根据神经网络建模的逻辑概念来浏览相应部分的API，如优化器、网络层、评价指标、模型保存和加载等。\n",
    "\n",
    "**3. API文档使用方法**\n",
    "\n",
    "飞桨每个API的文档结构一致，包含接口形式、功能说明和计算公式、参数和返回值、代码示例四个部分。 以abs函数为例，API文档结构如 **图5** 所示。通过飞桨API文档，读者不仅可以详细查看函数功能，还可以通过可运行的代码示例来实践API的使用。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/225d39636f384b519e7c926e5872dafaffe4a397af5a4b50aaeb1c12518e8a8b\" width=\"600\" hegiht=\"\" ></center>\n",
    "<center><br>图5：abs函数的API文档</br></center>\n",
    "<br></br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 模型设计\n",
    "\n",
    "在房价预测深度学习任务中，我们使用了单层且没有非线性变换的模型，取得了理想的预测效果。在手写数字识别中，我们依然使用这个模型预测输入的图形数字值。其中，模型的输入为784维（28\\*28）数据，输出为1维数据，如 **图6** 所示。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/65c0a269e7ab49ed995bc603c9b3e5aec31f33f7bf4544cea611b4c09ee861db\" width=\"400\" hegiht=\"\" ></center>\n",
    "<center><br>图6：手写数字识别网络模型</br></center>\n",
    "<br></br>\n",
    "\n",
    "输入像素的位置排布信息对理解图像内容非常重要（如将原始尺寸为28\\*28图像的像素按照7\\*112的尺寸排布，那么其中的数字将不可识别），因此网络的输入设计为28\\*28的尺寸，而不是1\\*784，以便于模型能够正确处理像素之间的空间信息。\n",
    "\n",
    "------\n",
    "**说明：**\n",
    "\n",
    "事实上，采用只有一层的简单网络（对输入求加权和）时并没有处理位置关系信息，因此可以猜测出此模型的预测效果可能有限。在后续优化环节介绍的卷积神经网络则更好的考虑了这种位置关系信息，模型的预测效果也会有显著提升。\n",
    "\n",
    "------\n",
    "\n",
    "下面以类的方式组建手写数字识别的网络，代码如下所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 定义mnist数据识别网络结构，同房价预测网络\n",
    "class MNIST(fluid.dygraph.Layer):\n",
    "    def __init__(self, name_scope):\n",
    "        super(MNIST, self).__init__(name_scope)\n",
    "        name_scope = self.full_name()\n",
    "        # 定义一层全连接层，输出维度是1，激活函数为None，即不使用激活函数\n",
    "        # self.fc = FC(name_scope, size=1, act=None)  # 因为1.7版本之后这个FC删除了，所以不能使用\n",
    "        self.fc = Linear(input_dim=28*28, output_dim=1, act=None)  # 28*28表示图像的维度大小\n",
    "        \n",
    "    # 定义网络结构的前向计算过程\n",
    "    def forward(self, inputs):\n",
    "        outputs = self.fc(inputs)\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 训练配置\n",
    "\n",
    "训练配置需要先生成模型实例（设为“训练”状态），再设置优化算法和学习率（使用随机梯度下降SGD，学习率设置为0.01），代码如下所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 定义飞桨动态图工作环境\n",
    "with fluid.dygraph.guard():\n",
    "    # 声明网络结构\n",
    "    model = MNIST(\"mnist\")\n",
    "    # 启动训练模式\n",
    "    model.train()\n",
    "    # 定义数据读取函数，数据读取batch_size设置为16\n",
    "    train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16)\n",
    "    # 定义优化器，使用随机梯度下降SGD优化器，学习率设置为0.001\n",
    "    optimizer = fluid.optimizer.SGDOptimizer(parameter_list=model.parameters(), learning_rate=0.001)  # 这边设置优化参数一定要给他的参数parameter_list=model.parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 训练过程\n",
    "\n",
    "训练过程采用二层循环嵌套方式，训练完成后需要保存模型参数，以便后续使用。\n",
    "\n",
    "- 内层循环：负责整个数据集的一次遍历，遍历数据集采用分批次（batch）方式。\n",
    "- 外层循环：定义遍历数据集的次数，本次训练中外层循环10次，通过参数EPOCH_NUM设置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, batch: 1000, loss is: [2.180437]\n",
      "epoch: 0, batch: 2000, loss is: [4.2682085]\n",
      "epoch: 0, batch: 3000, loss is: [3.312719]\n",
      "epoch: 1, batch: 1000, loss is: [1.9611645]\n",
      "epoch: 1, batch: 2000, loss is: [4.0380054]\n",
      "epoch: 1, batch: 3000, loss is: [3.2723315]\n"
     ]
    }
   ],
   "source": [
    "# 通过with语句创建一个dygraph运行的context，\n",
    "# 动态图下的一些操作需要在guard下进行\n",
    "with fluid.dygraph.guard():\n",
    "    model = MNIST(\"mnist\")\n",
    "    model.train()\n",
    "    train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16)\n",
    "    optimizer = fluid.optimizer.SGDOptimizer(parameter_list=model.parameters(), learning_rate=0.001)\n",
    "    EPOCH_NUM = 1000\n",
    "    for epoch_id in range(EPOCH_NUM):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            #准备数据，格式需要转换成符合框架要求的\n",
    "            image_data = np.array([x[0] for x in data]).astype('float32')  # (16, 28*28)\n",
    "            label_data = np.array([x[1] for x in data]).astype('float32').reshape(-1, 1)\n",
    "            # 将数据转为飞桨动态图格式\n",
    "            image = fluid.dygraph.to_variable(image_data)\n",
    "            label = fluid.dygraph.to_variable(label_data)\n",
    "            \n",
    "            #前向计算的过程\n",
    "            predict = model(image)  # 这边会直接调用forward函数\n",
    "            \n",
    "            #计算损失，取一个批次样本损失的平均值\n",
    "            loss = fluid.layers.square_error_cost(predict, label)\n",
    "            avg_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            #每训练了1000批次的数据，打印下当前Loss的情况\n",
    "            if batch_id !=0 and batch_id  % 1000 == 0:\n",
    "                print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.numpy()))\n",
    "            \n",
    "            #后向传播，更新参数的过程\n",
    "            avg_loss.backward()\n",
    "            optimizer.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "    # 保存模型\n",
    "    fluid.save_dygraph(model.state_dict(), 'mnist1000')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "通过观察上述代码可以发现，手写数字识别的代码几乎与房价预测任务一致，如果不是下述读取数据的两行代码有所差异，我们会误认为这是房价预测的模型。\n",
    "\n",
    "```\n",
    "   #准备数据，格式需要转换成符合框架要求的\n",
    "   image_data = np.array([x[0] for x in data]).astype('float32')\n",
    "   label_data = np.array([x[1] for x in data]).astype('float32').reshape(-1, 1)\n",
    "```\n",
    "\n",
    "\n",
    "另外，从训练过程中损失所发生的变化可以发现，虽然损失整体上在降低，但到训练的最后一轮，损失函数值依然较高。可以猜测手写数字识别完全复用房价预测的代码，训练效果并不好。接下来我们通过模型测试，获取模型训练的真实效果。\n",
    "\n",
    "# 模型测试\n",
    "\n",
    "模型测试的主要目的是验证训练好的模型是否能正确识别出数字，包括如下四步：\n",
    "\n",
    "* 声明实例\n",
    "* 加载模型：加载训练过程中保存的模型参数，\n",
    "* 灌入数据：将测试样本传入模型，模型的状态设置为校验状态（eval），显式告诉框架我们接下来只会使用前向计算的流程，不会计算梯度和梯度反向传播。\n",
    "* 获取预测结果，取整后作为预测标签输出。\n",
    "\n",
    "在模型测试之前，需要先从'./demo/example_0.jpg'文件中读取样例图片，并进行归一化处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fd2a1b606d0>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQoAAAD8CAYAAACPd+p5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHIZJREFUeJzt3X14VNWdB/DvLwmhFXE15MWBJA2RQAkQEEZwV8qjUqjNU0GkD8KuGgprXF/Wrktr032Tfao2bdXq1q1rLK7Yuqh9WgmLiKXUFHGXphMIClrLi1EIaSDAqojKS377x9yEId5zb+ZmZu6d5Pt5njwzc86ce385xJ937plzjqgqiIicZPgdABEFHxMFEblioiAiV0wUROSKiYKIXDFREJGrpCUKEblKRN4Skd0iUpOs8xBR8kkyvkchIpkA/ghgFoD9AH4PYJGqvpHwkxFR0iXrimIqgN2quldVTwB4BsDcJJ2LiJIsK0nHHQFgX8zr/QCmmd6cm5urJSUlSQqFiGI1NTV1qGpePG2SlShciUg1gGoAKC4uRiQS8SsUogFFRN6Jt02yPnq0AiiKeV1olXVT1TpVDatqOC8vruRGRCmWrETxewBlIjJSRLIBLASwJknnIqIkS8pHD1U9JSK3A3gJQCaAJ1R1ZzLORUTJl7R7FKq6DsC6ZB2fiFKH38wkIldMFETkiomCiFz59j2KVMnNyURJ0SC/w4Ai/q/KC8TT8ZzaOTEdcyDG4fV4XmJPpJZ9J9Fx5HTCj5uWieL06dMYPXo0NmzYgMLCQuzevRvl5eW27y0pGoTGl4ps61LptHbG3SZTzBd8J9X8xzBIMuM+l9MxnY7n9Hs5xe/EdEyn4yU6Dq/H8xJ7Ik390j73N3mQlh89GhsbMWrUKJSWliI7Oxv19fV+h0TUr6VlomhtbUVRUdFZr3uqq6tDOBzGocOJvwwjGmjSMlH0nBov8unPf9XV1YhEIsgb5u0ynIjOSMtEUVhYiH37znwWGz58uI/REPV/aZkoLrnkEuzatQtvv/02Tpw4gTlz5vgdElG/lpajHllZWdi7d2/363HjxvkYTe8k+q6315GITodhv06Y2pnP5XQ8eBw5SNUIAQAc7zxhW35ORraxTTJGeoKuf/5WRJRQTBRE5IqJgohcMVEQkSsmCiJylZajHgOFl/khgPNIhPM8kPi/nJbhMNnJ61wV0zG9jpQ4ncs0uuF1Lo3fcz2SJb2jJ6KUYKIgIldMFETkiomCiFwxURCRKyYKInLF4dE05TTcdkpPOrWM+1zJWHbPSzuvcTgN4Xo5HieFERHZYKIgIldMFETkiomCiFwxURCRq34/6qHQhE7U8TpRy8u5vN5BHyzmndHePXXMWDfrqW/alpf88xZjm4zBg411nR9/bKz7+OqpxrrK+162Lf/WsF3GNpRcvKIgIldMFETkiomCiFwxURCRqz7dzBSRFgAfADgN4JSqhkUkB8CzAEoAtABYoKpH+xYmEfkpEVcUV6jqJFUNW69rAGxU1TIAG63XRJTGkjE8OhfA5dbzlQAaAHwr0ScpKSnB0KFDkZmZiebm5kQf3shpyNLr0KkXThOkVn1QYKz76c1fM9aN3LLVtrzzzyuMbd6qNv8Jnb/FvNtWXl2jsW79yctty9/5zjBjmx8Of8VY5zRcfKzTfgj33IzPGNv014lfTvqaKBTAr0REATymqnUAClS1zar/EwDzX20fvfzyy8jNzU3W4YnI0tdEMV1VW0UkH8AGEflDbKWqqpVEPkVEqgFUA0BxcXEfwyCiZOrTNZSqtlqPBwE8D2AqgHYRCQGA9XjQ0LZOVcOqGs7Ly4v73CKC2bNnY8qUKZ7jJ6Le8ZwoRGSIiAzteg5gNoAdANYAqLLeVgWgvq9B2nn11VexdetWvPjii9i0adOn6uvq6hAOh9FxOHX3DYj6q75cURQA2Cwi2wE0AnhBVdcDqAUwS0R2Afii9Trhhg8fDgDIz89HY+Onb4xVV1cjEokgd9jAu/FElGie71Go6l4AE23KDwOY2Zeg3Ozduxfz5s0DAJw6dQo7d+5M5umIBry0nD1aWlqK7du3+3LuVK6X+InD2pe37rvCWHfg5iJjXebO183t7rCf0fnsHfcb24zNPsdYh9nmqpGTqo11n//7HbblLVeZz/WdjZONdffkm3/nz4p5CNeEa2YSEdlgoiAiV0wUROSKiYKIXDFREJGrtBz1iIdAUro2phem0Y3jneZRjz9+b5yxbsgb24x1H8++2Fj36zt/YFuenznE2MbrCEDWe+aduDo/+si2PEPMO3498+vLjHV3LzJPGvSyi1h/HdlwMvB+YyKKGxMFEblioiAiV0wUROSKiYKIXDFREJGrfj886nVLwUQPgTkNI5rWdCz/xd8a24yq/72xTsPlxrp/+/GPjHWmYdBkTIJav9B+KBYAvvruXbbloefMWwqW/ctrxrqLS2801m0IP2Z/rqxzjW2Od54w1p2TEf8ks3TAKwoicsVEQUSumCiIyBUTBRG5YqIgIldMFETkqt8Pj3qdPWritI6l09Z1nbDdBwkAYJpHWfr8J8Y2Msj8T/fWX5u3w6vINtd54bS1oZOLBpmHHyP/8IhteXj2XxrbXHi9/daAAFC4aLex7sqf3WJb/uZlPzW2cRoCNW1R6GVtziDNUg1OJEQUWEwUROSKiYKIXDFREJErJgoicsVEQUSuAj88umTJEqxduxb5+fnYsWMHjhw5guuuuw4tLS0oKSnBc889hwsuuCBl8TgNgTrNsnRaxPW5Y39mW579zmFjm0MLzVvovXKVeWYmYB6WNA3tnZthHlI1L5HrnWlYcNslzxjbPPi/pca6X03JM9aNrN5vW175y0pjm3Vj1hnrnPoqnQX+imLx4sVYv3599+va2lrMnDkTu3btwsyZM1Fbm5TN0okoRuATxYwZM5CTk9P9ur6+HlVVVQCAqqoqrF692q/QiAaMwCeKntrb2xEKhQAAoVAIBw8e9Dkiov4v7RJFQUEB2traAABtbW3Iz8/3OSKi/i/tEsWcOXOwcuVKAMDKlSsxd+5cnyMi6v8CP+qxaNEiNDQ0oKOjAytWrEBNTQ0WLFiAFStWoLi4GD//+c/9DpGo33NNFCLyBICvADioquOtshwAzwIoAdACYIGqHhURAfAwgEoAxwEsVtWtfQlw1apVnyrbuHFjXw7ZLdGLxnqd7Vf7h6tsy/PbWoxtOi4JGesKHRaGdeJlaM9p9uggMQ+eeul7p0Vt/z5nr7HuZz+/xFh34dcO2Z/r+yOMbRp/bJ5BPCXb/ncO0kxQL3oT/ZMAev4l1wDYqKplADZarwHgywDKrJ9qAI8mJkwi8pNrolDVTQCO9CieC2Cl9XwlgGtiyp/SqC0AzhcR8//6iCgteL0eKlDVNuv5nwAUWM9HANgX8779VhkRpbE+f3BSVQUclm8yEJFqEYmISOTQIfvPiUQUDF4TRXvXRwrrsetbT60AimLeV2iVfYqq1qlqWFXDeXnm7+ITkf+8Do+uAVAFoNZ6rI8pv11EngEwDcB7MR9RAsfpTrSXbQid7uQfPH3cWHf+I0PtKzrNF2o1M//bWOfEy2iDUxunkQ0nTmuIwnA+r9v1bQ0/a6ybMt9+zcy8J8xbNt7xz+atHrd8/z9sy72sLeq1b5OhN8OjqwBcDiBXRPYDuBvRBPGciCwF8A6ABdbb1yE6NLob0eHRryUhZiJKMddEoaqLDFUzbd6rAG7ra1BEFCzp/S0QIkoJJgoicsVEQUSumCiIyFXgZ48mU6KHCp2GTjccLzHWZb/8mm25njYPqV1z7i5j3Wn9rLHOi2RMaEr00J/XyWlr/8l+fdGvHl1mbHPufvPkNNOWk05rrZo4/a2ZaPzffewVXlEQkSsmCiJyxURBRK6YKIjIFRMFEblioiAiVwN6eNSJl9mjTjLFPGwlg+z/GVq/Hja2GZZhnt3oNcZE/86moULAebjQ1C7LYQNDr8OtIcP6om0zzG3Kvr7dWDfz9etsyzdOMM9gNfWFU7+b+ig5g6O8oiCiXmCiICJXTBRE5IqJgohcMVEQkSsmCiJyFfjh0SVLlmDt2rXIz8/Hjh07sHz5cjz++OPoWrn7vvvuQ2VlpbG9QhO6UK73LfTEWNf50Uf25Q4TDp1i9xpjomeJOg1nOvEy0zLRzi1631O78263H6DcscE8cFmRbf/v5fRvZepb819Z3wT+imLx4sVYv379WWV33nknmpub0dzc7JgkiCgxAp8oZsyYgZycnLPKHnnkEVRUVGDJkiU4evSoT5ERDRyBTxQ93XLLLdizZw+am5sRCoWwbJn9AiN1dXUIh8PoOBz/4h9EdLa0SxQFBQXIzMxERkYGbrrpJjQ2Ntq+r7q6GpFIBLnD0u5XJAqctPuvqK3tzMZjzz//PMaPH+9jNEQDQ+BHPRYtWoSGhgZ0dHRgxYoVaGhoQHNzM0QEJSUleOyxx/wOkajfC3yiWLVq1Vmvly5dGld7gXhaKNfE+16b5os3yYz/mF5nZh7vNC8Ma9rbM9XDrabzOQ8/e1v02GTbJU8b66YtNW+GN2yF/UfhP5wIGdtMGXy494H5JO0+ehBR6jFREJErJgoicsVEQUSumCiIyFXgRz2SycvdcK8jADmZx4x1km0/2lB0f8TY5oNbzaMXWRnmOEwjG068jvQ49ZVTnZcYE71O6Cd6ytjm5DnmqVeSYV83NMN+4p9TDEHCKwoicsVEQUSumCiIyJVrohCRJ0TkoIjsiClbLiKtItJs/VTG1H1bRHaLyFsi8qVkBU5EqdObK4onAVxlU/5DVZ1k/awDABEpB7AQwDirzY9FPN4JI6LAcE0UqroJwJFeHm8ugGdU9RNVfRvAbgBT+xAfEQVAX4ZHbxeRGwFEACxT1aMARgDYEvOe/VbZp4hINYBqACguLu5DGM6c1sx0Yhpu8zoxac6Q48a6e7860bY85xevGdus/XCksW7xeQeNdU4S2U+Ac195GXJN9MQvp3bPHzNP4rrwR78z1p344sW25Z/P/q1DDEOMdfGSJK2a6fVm5qMALgIwCUAbgAfiPYCq1qlqWFXDXQvlElEweUoUqtquqqdVtRPA4zjz8aIVQFHMWwutMiJKY54ShYjEXpfNA9A1IrIGwEIRGSwiIwGUAbCfoE9EacP1HoWIrAJwOYBcEdkP4G4Al4vIJER3WW8BcDMAqOpOEXkOwBsATgG4TdXhu7pElBZcE4WqLrIpXuHw/nsB3NuXoIgoWPjNTCJyxdmjBokeKnQ63v+NsS8//8MPjW3+teEaY91fXf2osS7DYfjMyxCj17U7vQx1et1G0cmvPrIfmnzwgQXGNnmZTca6/VfY/84XZX02vsAChlcUROSKiYKIXDFREJErJgoicsVEQUSumCiIyFXgh0fb2trQ1taGyZMnY/To0Vi9ejWefPJJ5OTkoKamBrW1taipqUn4eb3ORjTphBrrps983ba8dbl5kdnye8xTaKZceL2xbkt4pbHuHIl/UdsseFtuxEv/tp0yL1D8zinz8ONfrrvVWDf6Sfsh6IJ9e4xt3vx3+9m+APB65Q9tyzPlM8Y2XrZQTLXAJ4pQKIRQKDq1ZOzYsWhtbUV9fT0aGhoAAFVVVT5GRzQwpNVHj23btmHatGlob2/vTh5dj0SUPGmTKI4dO4aHHnoI5513Xq/eX1dXh3A4jI7Dwd8zgSjo0iJRnDx5EvPnz8e1114LACgoKEBbWxsAdD/2VF1djUgkgtxhafErEgVa4P8rUlUsXboUY8eO7S6bM2cOVq6M3pTreiSi5BFV8934VAmHwxqJ2G+ft3nzZnzhC1/AhAkTkJGRgfvuuw/Tpk3DggUL8O6776K4uBgbN240H3viZ9D4UpFtnZeJSU5tnEY2vNzBvujZvzHWld211ViXOeJCY90b3y4w1v3NXzTYll/82RZjm+0fm9c7Pdlpvlf+X0/PNNaJYX5X8dN7jW1OH+ow1iHT3Pcyxn7t0ZErWoxtfjT8f8znMoXgMMpjmljnNKnOZOqX9iGy/WPH94hIk6qG4zlu4Ec9pk+fDrtk5pQciCixAv/Rg4j8x0RBRK6YKIjIFRMFEblioiAiV4Ef9UgmpyGr450nbMsHi7nLnIZAvawtuee6/zC2Kc2+2Vj3+YcOGevG3GHepvDlzgtsyxvGTDG26dz5lrFOsszDe4WwHw530ll+kbGufV6psW7MX/3BWPfdwp/YlhdnnWNsk+gJg14n1qUSryiIyBUTBRG5YqIgIldMFETkiomCiFwN6FEPpwle52TEvyyc025VXib4OMW3d95jxrrvzSgz1v3nL2cZ60of2GFb/l75+cY2HQsvNdZlnDDvSpZ3mf3yAABwa0mDbXlp9qvGNlMHx9+/AHBS7ZfQ8zqyYfo3czqel3OZ/tbUYWJiX/CKgohcMVEQkSsmCiJyxURBRK6YKIjIFRMFEblyHR4VkSIATwEoAKAA6lT1YRHJAfAsgBIALQAWqOpREREADwOoBHAcwGJVNS/w6KNET+5J9M5OXuP71rBd5rqbzHW4yVSxyVMciedtCNTxiAH5N4uXKW6BeUi6L3rzW50CsExVywFcCuA2ESkHUANgo6qWAdhovQaALwMos36qATya8KiJKKVcE4WqtnVdEajqBwDeBDACwFwAXWvlrwRwjfV8LoCnNGoLgPNFhNt5EaWxuK6TRKQEwMUAfgegQFW7vl73J0Q/mgDRJLIvptl+q6znsapFJCIikUOHzOsnEJH/ep0oRORcAL8A8Heq+n5snUbX04/ru6OqWqeqYVUN5+XlxdOUiFKsV4lCRAYhmiSeVtVfWsXtXR8prMeDVnkrgNgddwqtMiJKU66JwhrFWAHgTVV9MKZqDYAq63kVgPqY8hsl6lIA78V8RCGiNNSbK4rLANwA4EoRabZ+KgHUApglIrsAfNF6DQDrAOwFsBvA4wBu7UuA+/btwxVXXIGxY8fi4YcfBgAsX74cI0aMwKRJkzBp0qS+HJ6IesH1exSquhkwDs5+agNJ637FbX2Mq1tWVhYeeOABTJ48GaNHj8asWdFp0nfeeSe+8Y1vJOo0ROQg8OtRhEIhhELR0dWxY8eitZW3O4hSLa2+wr1t2zZMmzYNAPDII4+goqICS5YssX1vXV0dwuEwDh02LyZDRL2TNoni2LFjeOihh3DeeefhlltuwZ49e9Dc3Nx9tdFTdXU1IpEI8oYFf88EoqAL/EcPADh58iTmz5+Pl156CQBQUFDQXXfTTcYJCkSUIIG/oti8eTOys7PR1taGSZMmYd26dbjhhhswYcIEVFRU4I477vA7RKJ+L/BXFNOnT0d0IOWMyspKn6IhGpgCf0VBRP5joiAiV0wUROSKiYKIXDFREJErJgoicsVEQUSumCiIyBUTBRG5YqIgIldMFETkiomCiFwFflJYX71zYCimfuXPul93dHQgNzfXx4iCFYffgtIP/SWOdw78XwKjOaPfJ4qemwuFw2FEIhGfogleHH4LSj8wDmf86EFErpgoiMjVgEsU1dXVfocAIDhx+C0o/cA4nEnP1aP8EA6HNYify4j6IxFpUtVwPG0GzBXF+vXrMWbMGIwaNQq1tbXuDZKkpKQEEyZMSPkOZ0uWLEF+fj7Gjx/fXTZr1iyUlZVh1qxZOHr0qC9x9Nz1bd26dSmJo2sHunHjxnXvQHfkyJGz+sSPGGL7I1V90Suq6vvPlClTNJlOnTqlpaWlumfPHv3kk0+0oqIiqedz8rnPfU4PHTqU8vP+9re/1aamJh03blx32Xe/+93ux7vuusuXOO6++279wQ9+kJJzxzpw4IA2NTXp+++/r2VlZbpz50795je/eVaf+BFDKvoDQETj/G90QFxRNDY2YtSoUSgtLUV2djYWLlzod0gpN2PGDOTk5JxVVlVV1f24evVq3+LwQygUwuTJkzF06NDuHejq6+vP6hM/YgiqAZEoWltbUVRU1P26sLDQt1hEBLNnz8aUKVN8i6FL1+ZJoVAIBw8e9C2O2F3fUvURqEtLS0v3DnTt7e1n9YkfMQBn+iPVfeFkQCQKDcAN2y6vvvoqtm7dihdffBGbNm3yOxzf9dz1bdmyZSk9//z587t3oPNLbAyx/ZHqvnAyIBLFvHnz8Jvf/AZvv/02Tpw4gfvvv9+XOD788EMMHToUADBkyBAcP37clzi6dN3Ura2t9W3Htc7OTmRmZiIjIwO5ubkp6xNVxY033oimpiZce+21AKK7zsX2iR8xxPaH338fZ4n3pkYyfpJ9M1NV9YUXXtCysjItLS3Ve+65J+nns7Nnzx6tqKjQiooKLS8vT+m5Fy5cqBdeeKFmZWXpiBEj9Cc/+YleeeWVOmrUKL3yyiv18OHDvsRx/fXX6/jx43XChAl69dVX64EDB1ISxyuvvKIAdOLEiTpx4kR94YUXtKOj46w+8SOG2P5IVl/Aw81Mfo+CaIDh9yiIKCmYKIjIlWuiEJEiEXlZRN4QkZ0i8nWrfLmItIpIs/VTGdPm2yKyW0TeEpEvJfMXIKLk6816FKcALFPVrSIyFECTiGyw6n6oqmcNIYhIOYCFAMYBGA7g1yIyWlVPJzJwIkod1ysKVW1T1a3W8w8AvAlghEOTuQCeUdVPVPVtALsBTE1EsETkj7juUYhICYCLAfzOKrpdRF4TkSdE5AKrbASAfTHN9sM5sRBRwPU6UYjIuQB+AeDvVPV9AI8CuAjAJABtAB6I58QiUi0iERGJ9FyujoiCpVdrZorIIESTxNOq+ksAUNX2mPrHAay1XrYCKIppXmiVnUVV6wDUWe0PiciHADo8/A5+yAVjTQbGmhw9Y/1cvAdwTRQiIgBWAHhTVR+MKQ+papv1ch6AHdbzNQD+S0QeRPRmZhmARqdzqGqeiETi/RKIXxhrcjDW5EhErL25orgMwA0AXheRZqvsHwAsEpFJABRAC4CbAUBVd4rIcwDeQHTE5DaOeBClN9dEoaqbAYhNlXH5HVW9F8C9fYiLiAIkSN/MrPM7gDgw1uRgrMnR51gDMSmMiIItSFcURBRQvicKEbnKmhOyW0Rq/I6nJxFpEZHXrfksEassR0Q2iMgu6/ECt+MkMb4nROSgiOyIKbONT6L+zerr10RkcgBiDeScIYc5ToHr25TMx4p3AYtE/gDIBLAHQCmAbADbAZT7GZNNjC0AcnuUfR9AjfW8BsD3fIxvBoDJAHa4xQegEsCLiN6cvhTA7wIQ63IA37B5b7n19zAYwEjr7yQzhbGGAEy2ng8F8EcrpsD1rUOsCetbv68opgLYrap7VfUEgGcQnSsSdHMBrLSerwRwjV+BqOomAEd6FJvimwvgKY3aAuB8EUnZKrKGWE18nTOk5jlOgetbh1hN4u5bvxNFOswLUQC/EpEmEena761Az3zZ7E8ACvwJzcgUX1D7O9BzhnrMcQp03yZrPpbfiSIdTFfVyQC+DOA2EZkRW6nRa7nADh0FPT70cc5QstnMceoWtL5N9HysWH4nil7NC/GTqrZajwcBPI/oJVp712Wl9ejfphj2TPEFrr9VtV1VT6tqJ4DHceYS2PdY7eY4IaB9a5qPlai+9TtR/B5AmYiMFJFsRBe8WeNzTN1EZIhEF+uBiAwBMBvROS1rAHRtJVUFoN6fCI1M8a0BcKN1h/5SAO/FXEb7osfn+J5zhhaKyGARGYlezBlKcFy2c5wQwL41xZrQvk3VnVmHO7aViN6l3QPgH/2Op0dspYjeHd4OYGdXfACGAdgIYBeAXwPI8THGVYheVp5E9LPmUlN8iN6R/3err18HEA5ArD+1YnnN+gMOxbz/H61Y3wLw5RTHOh3RjxWvAWi2fiqD2LcOsSasb/nNTCJy5fdHDyJKA0wUROSKiYKIXDFREJErJgoicsVEQUSumCiIyBUTBRG5+n8Pl/dQQS97/wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 导入图像读取第三方库\n",
    "import matplotlib.image as mpimg\n",
    "import matplotlib.pyplot as plt\n",
    "# 读取图像\n",
    "example = mpimg.imread('./work/example_0.png')\n",
    "# 显示图像\n",
    "plt.imshow(example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[255 255 255 ... 255 255 255]\n",
      " [255 255 255 ... 255 255 255]\n",
      " [255 255 255 ... 255 255 255]\n",
      " ...\n",
      " [255 255 255 ... 255 255 255]\n",
      " [255 255 255 ... 255 255 255]\n",
      " [255 255 255 ... 255 255 255]]\n",
      "本次预测的数字是 [[4]]\n"
     ]
    }
   ],
   "source": [
    "# 读取一张本地的样例图片，转变成模型输入的格式\n",
    "def load_image(img_path):\n",
    "    # 从img_path中读取图像，并转为灰度图\n",
    "    im = Image.open(img_path).convert('L')\n",
    "    print(np.array(im))\n",
    "    im = im.resize((28, 28), Image.ANTIALIAS)\n",
    "    im = np.array(im).reshape(1, -1).astype(np.float32)\n",
    "    # 图像归一化，保持和数据集的数据范围一致\n",
    "    im = 1 - im / 127.5  # 图像的范围是0~255，归一化的范围是-1~0， 所以是1-im/127.5\n",
    "    return im\n",
    "\n",
    "# 定义预测过程\n",
    "with fluid.dygraph.guard():\n",
    "    model = MNIST(\"mnist\")\n",
    "    params_file_path = 'mnist'\n",
    "    img_path = './work/example_0.png'\n",
    "# 加载模型参数\n",
    "    model_dict, _ = fluid.load_dygraph(\"mnist\")\n",
    "    model.load_dict(model_dict)\n",
    "# 灌入数据\n",
    "    model.eval()\n",
    "    tensor_img = load_image(img_path)\n",
    "    result = model(fluid.dygraph.to_variable(tensor_img))\n",
    "#  预测输出取整，即为预测的数字，打印结果\n",
    "    print(\"本次预测的数字是\", result.numpy().astype('int32'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "从打印结果来看，模型预测出的数字是与实际输出的图片的数字不一致。这里只是验证了一个样本的情况，如果我们尝试更多的样本，可发现许多数字图片识别结果是错误的。因此完全复用房价预测的实验并不适用于手写数字识别任务！\n",
    "\n",
    "接下来我们会对手写数字识别实验模型进行逐一改进，直到获得令人满意的结果。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 作业 2-1：\n",
    "\n",
    "1. 使用飞桨API [paddle.dataset.mnist](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/data/dataset_cn.html)的test函数获得测试集数据，计算当前模型的准确率。\n",
    "\n",
    "2. 怎样进一步提高模型的准确率？可以再接下来内容开始前，写出你想到的优化思路。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.7.0 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
